#!/usr/bin/env python
# coding: utf-8

# # Prerequisites

# In[10]:


print("Loading dependencies")
import os
import numpy as np
import matplotlib.pyplot as plt
from mma import splx2bf, approx
from classif_helper import *
import gudhi as gd
from sklearn.neighbors import KernelDensity
from multiprocessing import cpu_count

# # Dataset generation

# In[11]:


def pt(low=1, high=1.1, k=2, sigma=1):
    """Draw one point from a noisy ``k``-mode annulus.

    The radius is sampled so that points are uniform in area over the annulus
    ``low <= r <= high``. The angle is one of ``k`` equally spaced directions
    ``2*pi*m/k`` (``m`` drawn uniformly from ``range(k)``), plus Gaussian noise
    of standard deviation ``sigma``.

    Args:
        low: inner radius of the annulus.
        high: outer radius of the annulus.
        k: number of angular modes.
        sigma: standard deviation of the angular noise.

    Returns:
        Tuple ``(x, y)`` of Cartesian coordinates.
    """
    # Uniform-in-area sampling requires r**2 ~ U(low**2, high**2). Using `low`
    # instead of `low**2` only coincides with this when low == 1.
    r = np.sqrt(np.random.uniform(low=low**2, high=high**2))
    theta = np.random.choice(range(k)) * 2 * np.pi / k + np.random.normal(loc=0, scale=sigma)
    return r * np.cos(theta), r * np.sin(theta)
def orbit(n: int = 100, r=0.5, x0=None) -> list:
    """Return ``n`` successive iterates of a 2-D map on the unit torus.

    Each step updates ``x <- (x + r*y*(1-y)) mod 1`` and then
    ``y <- (y + r*x*(1-x)) mod 1``, using the freshly updated ``x``.

    Args:
        n: number of points to return. The starting point is always included,
            so ``n <= 1`` yields a single point.
        r: strength of the map.
        x0: starting point ``[x, y]``. When ``None`` (or anything whose length
            is not 2), a random start is drawn uniformly from ``[0, 1)^2``.

    Returns:
        List of ``[x, y]`` pairs.
    """
    # `None` instead of a mutable `[]` default; an explicit `[]` still means
    # "random start", as before.
    if x0 is None or len(x0) != 2:
        x, y = np.random.uniform(size=2)
    else:
        x, y = x0
    point_list = [[x, y]]
    for _ in range(n - 1):
        x = (x + r * y * (1 - y)) % 1
        y = (y + r * x * (1 - x)) % 1
        point_list.append([x, y])
    return point_list
def get_pts(dataset: str = "annulus", npts: int = 100, **kwargs) -> np.ndarray:
    """Generate a synthetic point cloud as a NumPy array.

    * ``"annulus"``: ``npts`` independent draws of ``pt(**kwargs)``.
    * ``"orbit"``: the ``npts``-point trajectory ``orbit(npts, **kwargs)``.
    * anything else: an empty array.
    """
    if dataset == "annulus":
        points = [pt(**kwargs) for _ in range(npts)]
    elif dataset == "orbit":
        points = orbit(npts, **kwargs)
    else:
        points = []
    return np.array(points)


# In[40]:


# Parameters
# Size of the generated point cloud and number of angular modes of the annulus.
npts, k = 100_000, 4


# In[55]:


print("Generating dataset...")
X = get_pts(npts=npts, dataset="annulus", k=k, sigma=0.2)
# 2-D histogram of the sample, normalised to a probability density.
heatmap, xedges, yedges = np.histogram2d(X[:, 0], X[:, 1], bins=100, density=True)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
# Pass `extent` so the axes show data coordinates rather than bin indices.
# Transpose because histogram2d indexes the result as [x_bin, y_bin].
plt.imshow(heatmap.T, origin='lower', extent=extent)
plt.colorbar()
os.makedirs("images", exist_ok=True)
plt.savefig(f"images/annulus_heatmap_{npts}_pts_{k}_modes.png", dpi=200)
plt.clf()


# # Filtrations

# In[56]:


# Experiment settings. Besides the uses visible in this script, the whole dict
# is forwarded as keyword arguments to compute_img / compute_imgs
# (classif_helper), which presumably consume the remaining keys.
params = dict(
    n_jobs=cpu_count(),
    K=25_000,                 # sample size of the reference ("last") image
    kmin=50,                  # smallest sample size of the approximations
    kmax=20_000,              # largest sample size of the approximations
    nsamples=500,             # number of approximation sample sizes
    precision=0.01,
    dimension=1,
    resolution=[200, 200],    # pixels per image channel (product of both)
    kde_bandwidth=0.3,        # read by get_bf for the KDE
    box=[[0, 2], [1.25, 3]],
    kde_kernel="gaussian",    # read by get_bf for the KDE
    normalize=0,
    img_bandwidth=0.01,
    ps=[0, 0.5, 1, np.inf],   # one image channel per value of p
    threshold=2,              # read by get_bf for the alpha complex
    flatten=True,
)


# In[57]:


def get_bf(k, **kwargs):
    """Build a boundary matrix and two filtrations on the first ``k`` points of ``X``.

    Reads the module-level point cloud ``X``. An alpha complex is built on the
    sample and converted with ``splx2bf``. The second filtration is the KDE
    codensity (negative log-density) evaluated at the complex's vertices.

    Args:
        k: number of points of ``X`` to use.
        **kwargs: settings read here are ``threshold`` (default 4),
            ``kde_kernel`` (default ``'gaussian'``) and ``kde_bandwidth``
            (default 0.5). Other keys are ignored.

    Returns:
        ``(boundary, [alpha_filtration, codensity_filtration])``, where
        ``boundary`` and ``alpha_filtration`` are whatever ``splx2bf`` returns.
    """
    # `X[1:k]` skipped the first point and used only k-1 points, although k
    # is what the error plots report as the "Number of points".
    sample = X[:k]
    alphacplx = gd.AlphaComplex(points=sample)
    # NOTE(review): `max_alpha_square` receives sqrt(threshold). If `threshold`
    # is meant as a radius this should probably be threshold**2 -- confirm.
    st = alphacplx.create_simplex_tree(max_alpha_square=np.sqrt(kwargs.get("threshold", 4)))
    boundary, alpha_filtration = splx2bf(st)
    # Evaluate the density at the complex's own vertex coordinates, indexed by
    # vertex id -- presumably so the values line up with splx2bf's vertex
    # order (TODO confirm).
    points = np.array([alphacplx.get_point(i) for i in range(st.num_vertices())])
    kde = KernelDensity(
        kernel=kwargs.get("kde_kernel", 'gaussian'),
        bandwidth=kwargs.get("kde_bandwidth", 0.5),
    ).fit(sample)
    # score_samples returns log-densities; negating gives a codensity.
    codensity_filtration = -np.array(kde.score_samples(points))
    return boundary, [alpha_filtration, codensity_filtration]


# In[63]:


print("Computing last image...")
# Reference image built from the largest sample (params["K"] points). Every
# approximation image is compared against it in the error section.
last_img = compute_img(
    get_bf(params["K"]),
    plot=False,
    size=5,
    **params,
)
with open(f"images/last_img_synth_{params['K']}_pts.npy", "wb") as f:
    np.save(f, last_img)

# # Computation

# In[7]:


print("Computing approximation images...")
# Sample sizes at which images are recomputed: params["nsamples"] values evenly
# spaced from params["kmin"] to params["kmax"], rounded down to int.
# A log-spaced alternative is
# np.logspace(np.log10(kmin), np.log10(kmax), nsamples, dtype=int).
iterator = np.linspace(
    params["kmin"],
    params["kmax"],
    params["nsamples"],
    dtype=int,
)
approximation_images = compute_imgs(
    iterator,
    get_bf,
    multithreads=True,
    **params,
)


# # Error plot

# In[ ]:


print("Saving errors")

# Each image is the concatenation of one flattened channel of `npixels` pixels
# per value of p in params["ps"].
nps = len(params["ps"])
npixels = np.prod(params["resolution"])


def _approximation_errors(images, reference, n_channels, channel_size):
    """Compare every image with ``reference``, channel by channel, in one pass.

    Returns a dict mapping a metric name to a ``(len(images), n_channels)``
    float array:

    * ``"L2"``: mean squared pixel difference (no square root is taken),
    * ``"Linf"``: maximum absolute pixel difference,
    * ``"ScaledL2"`` / ``"ScaledLinf"``: the same values divided by the
      maximum pixel value of the corresponding reference channel.
    """
    # Materialise once so the row count comes from the actual images and the
    # input is iterated only a single time.
    images = list(images)
    errors = {
        name: np.zeros(shape=(len(images), n_channels))
        for name in ("L2", "ScaledL2", "Linf", "ScaledLinf")
    }
    for i, img in enumerate(images):
        for j in range(n_channels):
            last = reference[channel_size * j:channel_size * (j + 1)]
            current = img[channel_size * j:channel_size * (j + 1)]
            mse = np.square(last - current).mean()
            sup = np.abs(last - current).max()
            scale = last.max()
            errors["L2"][i, j] = mse
            errors["ScaledL2"][i, j] = mse / scale
            errors["Linf"][i, j] = sup
            errors["ScaledLinf"][i, j] = sup / scale
    return errors


def _save_and_plot_errors(file_str, errors, xs, ps, ylabel, title, loglog=False, skip_first=False):
    """Save ``errors`` to ``file_str + ".npy"`` and plot one curve per p to ``file_str + ".svg"``.

    ``loglog`` selects log-log axes. ``skip_first`` omits the first sample
    size from the plot, but not from the saved array.
    """
    with open(file_str + ".npy", 'wb') as f:
        np.save(f, errors)
    draw = plt.loglog if loglog else plt.plot
    start = 1 if skip_first else 0
    for i, p in enumerate(ps):
        draw(xs[start:], errors[start:, i], label=f"p={p}")
    plt.xlabel("Number of points")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.savefig(file_str + ".svg")
    plt.clf()


# (file prefix, metric, y label, title prefix, log-log axes, drop first sample)
# The last entry previously had the y label "2-norm" although it plots an
# inf-norm; it is now labelled "inf-norm" like the other Linf plots.
_ERROR_PLOTS = (
    ("L2mean", "L2", "2-norm", "Error", False, False),
    ("ScaledL2mean", "ScaledL2", "2-norm", "Scaled error", False, True),
    ("loglogL2mean", "L2", "2-norm", "Error", True, False),
    ("loglogScaledL2mean", "ScaledL2", "2-norm", "Scaled error", True, False),
    ("Linf", "Linf", "inf-norm", "Error", False, False),
    ("ScaledLinf", "ScaledLinf", "inf-norm", "Scaled error", False, True),
    ("loglogLinf", "Linf", "inf-norm", "Error", True, False),
    ("loglogScaledLinf", "ScaledLinf", "inf-norm", "Scaled error", True, False),
)

os.makedirs("errors", exist_ok=True)
error_matrices = _approximation_errors(approximation_images, last_img, nps, npixels)
nimgs = len(error_matrices["L2"])
for prefix, metric, ylabel, title_prefix, loglog, skip_first in _ERROR_PLOTS:
    file_str = f"errors/{prefix}_annulus_cv_{k}_modes_{params['K']}_npts"
    errors = error_matrices[metric]
    _save_and_plot_errors(
        file_str,
        errors,
        iterator,
        params["ps"],
        ylabel,
        f"{title_prefix} w.r.t. the image computed with {params['K']} points",
        loglog=loglog,
        skip_first=skip_first,
    )



